import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

from knowledge_tracing.args import ARGS as args
from knowledge_tracing.network.util_network import clones


def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Computes softmax(query @ key^T / sqrt(d_k)) and applies it to ``value``,
    where d_k is the size of the last axis of ``query``.

    Args:
        query, key, value: tensors whose last two axes are (length, features).
        mask: optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` receive a score of -1e9 before the softmax.
        dropout: optional module applied to the attention weights.

    Returns:
        ``(weights @ value, weights)``. The returned weights are the ones
        after dropout, if dropout was given.
    """
    scale = math.sqrt(query.size(-1))
    logits = query @ key.transpose(-2, -1) / scale
    if mask is not None:
        # A large negative (not -inf) keeps fully-masked rows finite.
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = logits.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ value, weights


class MultiHeadedAttention(nn.Module):
    """Multi-head attention built from four d_model -> d_model projections.

    ``linears[0..2]`` project query, key and value; ``linears[3]`` mixes the
    concatenated head outputs. ``self.attn`` holds the attention weights of
    the most recent forward call.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """Create ``h`` heads over a model width of ``d_model``.

        ``d_model`` must be divisible by ``h``; each head gets
        ``d_model // h`` features for keys and values alike.
        """
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.h = h
        # Per-head width; values use the same width as keys.
        self.d_k = d_model // h
        # NOTE(review): clones() is defined elsewhere; if it deep-copies, all
        # four projections start from identical weights — confirm init policy.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value`` with all heads in parallel."""
        batch_size = query.size(0)
        if mask is not None:
            # Insert a head axis so the same mask is shared by every head.
            mask = mask.unsqueeze(1)

        def _split_heads(projection, tensor):
            # Project, then reshape to [batch, h, length, d_k].
            projected = projection(tensor)
            return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        q, k, v = (
            _split_heads(projection, tensor)
            for projection, tensor in zip(self.linears, (query, key, value))
        )

        context, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # Merge heads back: [batch, h, length, d_k] -> [batch, length, h * d_k].
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.h * self.d_k)
        return self.linears[-1](merged)


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward net applied independently at every position.

    Computes ``w_2(dropout(relu(w_1(x))))`` mapping
    d_input -> d_ff -> d_output.
    """

    def __init__(self, d_input, d_ff, d_output, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_input, d_ff)
        self.w_2 = nn.Linear(d_ff, d_output)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with learnable gain and bias.

    Differs from ``nn.LayerNorm``: it divides by the *unbiased* standard
    deviation (``Tensor.std`` default) and adds ``eps`` to the std rather
    than to the variance.
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2


class SublayerConnection(nn.Module):
    """
    Pre-norm residual wrapper around a sublayer.

    Computes ``x + dropout(sublayer(norm(x)))``: the layer norm is applied
    to the sublayer's input, and the residual is added afterwards.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (a callable preserving x's shape) with a residual connection."""
        normed = self.norm(x)
        residual = self.dropout(sublayer(normed))
        return x + residual


class GELU(nn.Module):
    """
    Tanh approximation of the Gaussian Error Linear Unit, as used by BERT:
    0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    """
    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * x.pow(3))
        return 0.5 * x * (1 + torch.tanh(inner))


class EncoderLayer(nn.Module):
    """One encoder block: masked self-attention then a feed-forward net.

    Each of the two stages is wrapped in its own SublayerConnection
    (pre-norm residual with dropout).
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Run self-attention (queries, keys and values all from ``x``) then the FFN."""
        def _self_attend(t):
            return self.self_attn(t, t, t, mask)

        attended = self.sublayer[0](x, _self_attend)
        return self.sublayer[1](attended, self.feed_forward)


class Generator(nn.Module):
    """Output head: d_model -> d_model/2 -> d_model/4 -> 1.

    The two hidden stages apply Linear, LayerNorm and ReLU in that order;
    the final Linear is followed by a sigmoid, so outputs lie in (0, 1).
    """

    def __init__(self, d_model):
        super().__init__()
        half = int(d_model / 2)
        quarter = int(d_model / 4)
        self.fc_1 = nn.Linear(d_model, half)
        self.fc_2 = nn.Linear(half, quarter)
        self.fc_3 = nn.Linear(quarter, 1)

        self.ln_1 = nn.LayerNorm(half)
        self.ln_2 = nn.LayerNorm(quarter)

    def forward(self, x):
        h = x
        for fc, ln in ((self.fc_1, self.ln_1), (self.fc_2, self.ln_2)):
            h = F.relu(ln(fc(h)))
        return torch.sigmoid(self.fc_3(h))


class EmbeddingLayer(nn.Module):
    """Lookup table with ``input_size + 1`` rows whose output is scaled by sqrt(embed_size).

    ``padding_idx`` is forwarded unchanged to ``nn.Embedding``.
    """
    def __init__(self, input_size, embed_size, padding_idx=None):
        super(EmbeddingLayer, self).__init__()
        self.d_model = embed_size
        # One extra row so ids 0..input_size are all valid.
        self.lut = nn.Embedding(input_size + 1, embed_size,
                                padding_idx=padding_idx)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale


class MultiEmbeddingLayer(nn.Module):
    '''
    Embed features that may occur a variable number of times per position
    (e.g. keyword tags) and collate them into one vector per position.

    Args:
        input_size: largest tag id; the lookup table has input_size + 1 rows.
        embed_size: embedding width.
        padding_idx: id used to pad the tag lists (forwarded to the lookup
            table). When None, the 'max' and 'average' collations treat 0 as
            padding.
        collate: 'sum', 'max' or 'average' (default). Any other value raises
            NotImplementedError in forward.
    '''
    def __init__(self,
                 input_size,
                 embed_size,
                 padding_idx=None,
                 collate='average'):
        super().__init__()
        self.embed = EmbeddingLayer(input_size=input_size,
                                    embed_size=embed_size,
                                    padding_idx=padding_idx)
        self.d_model = embed_size
        self.collate = collate
        self.padding_idx = padding_idx

    def _pad_mask(self, x):
        # True where an entry of x is padding; 0 is the fallback pad id when
        # padding_idx is None (the value 'max' collation always masked).
        pad_value = 0 if self.padding_idx is None else self.padding_idx
        return x == pad_value

    def forward(self, x):
        # x: [batch_size, seq_length, max_tags] tag ids
        # embeddings: [batch_size, seq_length, max_tags, embed_dim]
        embeddings = self.embed(x)
        if self.collate == 'sum':
            # Rows at padding_idx are zero-initialised by nn.Embedding, so
            # padded slots add nothing when padding_idx is set.
            output = torch.sum(embeddings, dim=-2)
        elif self.collate == 'max':
            is_pad = self._pad_mask(x).unsqueeze(dim=-1)
            embeddings = embeddings.masked_fill(is_pad, -float('inf'))
            output, _ = torch.max(embeddings, dim=-2)
            # Positions whose tags are all padding would be -inf; report 0.
            output = output.masked_fill(output == float('-inf'), 0)
        elif self.collate == 'average':
            is_pad = self._pad_mask(x)
            count = (~is_pad).sum(dim=-1, keepdim=True)  # [batch_size, seq_length, 1]
            # Zero padded slots explicitly so the mean only covers real tags.
            total = embeddings.masked_fill(is_pad.unsqueeze(dim=-1), 0).sum(dim=-2)
            # clamp: all-padding positions give 0 / 1 = 0 instead of forming
            # 0 / 0 = NaN (which would also appear in the backward pass).
            output = total / count.clamp(min=1).float()  # [batch_size, seq_length, embed_dim]
        else:
            raise NotImplementedError(
                f'Collation method {self.collate} not implemented')

        return output


class overkillEmbeddingLayer(nn.Module):
    '''
    ABANDONED
    A layer for embedding features that may occur with multiplicity
    that varies between examples. E.g. keyword tags.
    Every possible tag id gets its own slot: slots for tags present at a
    position keep their id, absent ones become 0. The per-slot embeddings
    are concatenated and projected back to embed_size by a linear layer.
    '''
    def __init__(self, input_size, embed_size, padding_idx=None):
        super().__init__()
        self.embed = EmbeddingLayer(input_size=input_size,
                                    embed_size=embed_size,
                                    padding_idx=padding_idx)
        self.d_model = embed_size
        self.input_size = input_size
        self.linear = nn.Linear(self.d_model * (self.input_size + 1),
                                self.d_model)

    def _overkill_fn(self, x):
        # x: [batch_size, seq_length, max_tags] tag ids.
        # Returns int64 [batch_size, seq_length, input_size + 1] where slot t
        # holds t if tag t occurs in that position's tag list, else 0.
        # Vectorised: the previous version ran a Python `in` test per
        # (batch, position, tag id), which is extremely slow on tensors.
        x = torch.as_tensor(x)
        vocab = torch.arange(self.input_size + 1, device=x.device)
        # [batch, seq, max_tags, vocab] -> any over max_tags -> [batch, seq, vocab]
        present = (x.unsqueeze(-1) == vocab).any(dim=-2)
        y = torch.where(present, vocab, torch.zeros_like(vocab))
        return y.to(args.device)

    def forward(self, x):
        # x: [batch_size, seq_length, max_tags]
        x = self._overkill_fn(x)  # [batch_size, seq_length, self.input_size+1]
        embeddings = self.embed(
            x)  # [batch_size, seq_length, self.input_size+1, embed_dim]
        # Concatenate the per-slot embeddings, then project to embed_dim.
        output = torch.flatten(
            embeddings, start_dim=-2, end_dim=-1
        )  # [batch_size, seq_length, (self.input_size+1)*embed_dim]
        return self.linear(output)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Args:
        d_model: width of the last input axis. Odd widths are supported
            (there is one more sine column than cosine columns).
        dropout: dropout probability applied after the addition.
        max_len: number of positions in the precomputed table; inputs whose
            axis 1 is longer than this cannot be broadcast against it.
    """

    def __init__(self, d_model, dropout, max_len):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once, with frequencies in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model the odd columns number one fewer than div_term;
        # slicing avoids the shape mismatch (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # Leading axis of size 1 so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Buffers never require grad, so the deprecated Variable wrapper is unnecessary.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, attention over the encoder
    output (``memory``), then a feed-forward net.

    Each of the three stages is wrapped in its own SublayerConnection.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Self-attend with ``tgt_mask``, cross-attend to ``memory`` with ``src_mask``, then FFN."""
        def _self_attend(t):
            return self.self_attn(t, t, t, tgt_mask)

        def _cross_attend(t):
            return self.src_attn(t, memory, memory, src_mask)

        out = self.sublayer[0](x, _self_attend)
        out = self.sublayer[1](out, _cross_attend)
        return self.sublayer[2](out, self.feed_forward)


class Encoder(nn.Module):
    """Stack of ``N`` copies of ``layer`` followed by a final LayerNorm.

    When ``args.embed_sum`` is false, inputs first pass through
    ``Linear(input_dim, hidden_dim)`` before the stack.
    """

    def __init__(self, layer, N, input_dim, hidden_dim):
        super(Encoder, self).__init__()
        if not args.embed_sum:
            self.first_layer = nn.Linear(input_dim, hidden_dim)
        self.layers = clones(layer, N)

        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        h = x if args.embed_sum else self.first_layer(x)
        for block in self.layers:
            h = block(h, mask)
        return self.norm(h)


class Decoder(nn.Module):
    """Stack of ``N`` copies of a decoder ``layer`` followed by a final LayerNorm.

    When ``args.embed_sum`` is false, inputs first pass through
    ``Linear(input_dim, hidden_dim)``. Every layer receives the same
    ``memory`` and masks.
    """

    def __init__(self, layer, N, input_dim, hidden_dim):
        super(Decoder, self).__init__()
        if not args.embed_sum:
            self.first_layer = nn.Linear(input_dim, hidden_dim)
        self.layers = clones(layer, N)

        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        h = x if args.embed_sum else self.first_layer(x)
        for block in self.layers:
            h = block(h, memory, src_mask, tgt_mask)
        return self.norm(h)


class EncoderDecoder(nn.Module):
    """
    Thin wrapper that runs an encoder and then a decoder on its output.
    Both are supplied by the caller; this class only wires them together.
    """
    def __init__(self, encoder, decoder):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_embed, tgt_embed, src_mask, tgt_mask):
        """Encode the masked source, decode the masked target against it.

        Returns ``(decoder_output, encoder_output)``.
        """
        memory = self.encode(src_embed, src_mask)
        decoded = self.decode(memory, src_mask, tgt_embed, tgt_mask)
        return decoded, memory

    def encode(self, src_embed, src_mask):
        """Run the encoder as ``encoder(src_embed, src_mask)``."""
        return self.encoder(src_embed, src_mask)

    def decode(self, memory, src_mask, tgt_embed, tgt_mask):
        """Run the decoder as ``decoder(tgt_embed, memory, src_mask, tgt_mask)``."""
        return self.decoder(tgt_embed, memory, src_mask, tgt_mask)


class Embeddings(nn.Module):
    """Token lookup table of ``vocab`` rows; outputs are scaled by sqrt(d_model)."""

    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.lut = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)

    def forward(self, x):
        looked_up = self.lut(x)
        return looked_up * math.sqrt(self.d_model)


class TransformerEmbedding(nn.Module):
    """Sum the outputs of up to five optional embedding layers.

    The layers are stored in the fixed order (question, assessment, position,
    start_time, elapsed_time). ``forward`` takes a list in the same order;
    for each slot where both the layer and the input are not None, the input
    is cast to ``long`` and embedded. The embeddings are then summed
    elementwise.

    Args:
        embed_*: embedding modules (or None to skip that slot).
        dropout: dropout probability stored as ``self.dropout``; None means
            0.0 (previously None made ``nn.Dropout`` raise a TypeError).
            Note that ``forward`` does not currently apply it.
    """
    def __init__(self,
                 embed_question=None,
                 embed_assessment=None,
                 embed_position=None,
                 embed_start_time=None,
                 embed_elapsed_time=None,
                 dropout=None):
        super().__init__()
        # NOTE(review): a plain list does not register these layers as
        # submodules of this module, so their parameters are only trained,
        # moved by .to() and saved if a caller also registers them elsewhere.
        # Switching to nn.ModuleList would change state_dict keys — confirm
        # against callers/checkpoints before changing.
        self.embed_list = [
            embed_question, embed_assessment, embed_position, embed_start_time,
            embed_elapsed_time
        ]
        self.dropout = nn.Dropout(0.0 if dropout is None else dropout)

    def forward(self, embed_input_list):
        """Embed each available input with its layer and return the elementwise sum.

        Raises (from ``torch.stack``) if no slot has both a layer and an input.
        """
        embed_output_list = []
        for index, embed_layer in enumerate(self.embed_list):
            embed_input = embed_input_list[index]
            if embed_layer is None or embed_input is None:
                continue
            embed_output_list.append(embed_layer(embed_input.long()))

        # self.dropout is intentionally not applied to the sum here.
        return torch.stack(embed_output_list).sum(dim=0)
